from __future__ import division
import os
import time
import math
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python import debug as tfdbg
from data_loader import DataLoader
from nets import *
from utils import *

class SfMLearner(object):
    """Pose-supervised variant of SfMLearner.

    The original unsupervised pipeline (photometric reconstruction,
    depth smoothness, explainability regularization) has been disabled;
    this version trains ``pose_exp_net`` to regress relative camera
    poses against ground-truth poses fed through placeholders.
    """

    def __init__(self, opt):
        # opt: hyper-parameter namespace (batch_size, img_height,
        # img_width, num_source, num_scales, explain_reg_weight,
        # learning_rate, beta1, ...). Graph construction is deferred to
        # build_train_graph() / setup_inference().
        self.opt = opt

    @staticmethod
    def gradient_sobel(pred):
        """Return (D_dx, D_dy) finite-difference gradients of an NHWC tensor.

        Declared @staticmethod: the original definition lacked ``self``,
        so calling it as ``self.gradient_sobel(x)`` would raise a
        TypeError. As a staticmethod both call styles work.
        """
        D_dy = pred[:, 1:, :, :] - pred[:, :-1, :, :]
        D_dx = pred[:, :, 1:, :] - pred[:, :, :-1, :]
        return D_dx, D_dy

    def build_train_graph(self):
        """Build the supervised pose-training graph.

        Creates input placeholders (image stack, intrinsics, ground-truth
        relative poses), predicts one 6-DoF pose per (target, source)
        image pair, and minimizes an L2 pose loss with up-weighted
        rotation terms. Stores placeholders / losses / ops on ``self``
        for the training driver.
        """
        opt = self.opt
        with tf.name_scope("data_loading"):
            # Each of the num_source pairings contributes 6 channels:
            # 3 target RGB + 3 source RGB, stacked along the channel axis.
            image_concat = tf.placeholder(tf.uint8,
                                          shape=[opt.batch_size, opt.img_height,
                                                 opt.img_width,
                                                 opt.num_source * 6],
                                          name='imgs')
            # NOTE(review): "intrincis" is a typo for "intrinsics", but the
            # placeholder name and the self.intrincis attribute are part of
            # the externally-fed interface, so both are kept verbatim.
            # Currently unused by the loss (kept for the disabled
            # photometric path).
            intrincis = tf.placeholder(tf.float32,
                                       shape=[opt.batch_size, opt.num_scales, 3, 3],
                                       name='intrincis')
            # Ground-truth relative pose per source frame: 3 translation +
            # 3 rotation components (assumed tx,ty,tz,rx,ry,rz -- TODO
            # confirm ordering against the data pipeline).
            rel_pose = tf.placeholder(tf.float32,
                                      shape=[opt.batch_size, opt.num_source, 6])

            # First 3 channels (target RGB of the first pair), exposed for
            # the TensorBoard image summary.
            self.image_summary = image_concat[:, :, :, :3]

            image_concat_float = self.preprocess_image(image_concat)

        with tf.name_scope("pose_and_explainability_prediction"):
            # Stack the per-source 6-channel pairs along the batch axis so
            # the network predicts one pose per (target, source) pair:
            # output rows [bs*i, bs*(i+1)) belong to source i.
            pose_input = image_concat_float[:, :, :, :6]
            for i in range(opt.num_source - 1):
                pose_input = tf.concat(
                    [pose_input,
                     image_concat_float[:, :, :, 6 * (i + 1):6 * (i + 2)]],
                    axis=0)
            # Explainability logits / endpoints are unused by the active
            # supervised loss.
            pred_poses, _, _ = \
                pose_exp_net(pose_input,
                             do_exp=(opt.explain_reg_weight > 0),
                             is_training=True)

        with tf.name_scope("compute_loss"):
            # The original photometric / smoothness / explainability losses
            # are disabled; their accumulators stay at 0 (and the image
            # lists stay empty) only so downstream attribute consumers
            # keep working.
            pixel_loss = 0
            exp_loss = 0
            smooth_loss = 0
            pose_loss = 0
            tgt_image_all = []
            src_image_stack_all = []
            proj_image_stack_all = []
            proj_error_stack_all = []
            exp_mask_stack_all = []
            bs = opt.batch_size
            # Rotation residuals are numerically much smaller than
            # translation residuals, so they are up-weighted. Default 50
            # preserves the original behavior; override via opt.rot_weight.
            rot_weight = getattr(opt, 'rot_weight', 50.0)
            for i in range(opt.num_source):
                # Predictions for source i occupy batch rows [bs*i, bs*(i+1)).
                pred_pose = pred_poses[bs * i:bs * (i + 1)]
                pred_pose = tf.squeeze(pred_pose, axis=1)
                gt_rel_pose = rel_pose[:, i, :]

                tran_error = pred_pose[:, :3] - gt_rel_pose[:, :3]
                rot_error = pred_pose[:, 3:] - gt_rel_pose[:, 3:]

                pose_loss_tran = tf.reduce_mean(tf.square(tran_error))
                pose_loss_rot = tf.reduce_mean(tf.square(rot_error)) * rot_weight
                pose_loss += pose_loss_tran + pose_loss_rot
            total_loss = pose_loss

        with tf.name_scope("train_op"):
            optim = tf.train.AdamOptimizer(opt.learning_rate, opt.beta1)
            # slim.learning.create_train_op computes and applies gradients
            # of total_loss over all trainable variables.
            self.train_op = slim.learning.create_train_op(total_loss, optim)
            self.global_step = tf.Variable(0,
                                           name='global_step',
                                           trainable=False)
            self.incr_global_step = tf.assign(self.global_step,
                                              self.global_step + 1)

        # Collect tensors that are useful later (tf summaries, feed_dict
        # keys, loss reporting).
        self.pred_poses = pred_poses
        self.val_loss = pose_loss
        self.total_loss = total_loss
        self.pose_loss = pose_loss
        self.pixel_loss = pixel_loss
        self.exp_loss = exp_loss
        self.smooth_loss = smooth_loss
        self.tgt_image_all = tgt_image_all
        self.src_image_stack_all = src_image_stack_all
        self.proj_image_stack_all = proj_image_stack_all
        self.proj_error_stack_all = proj_error_stack_all
        self.exp_mask_stack_all = exp_mask_stack_all

        self.image_concat = image_concat
        self.intrincis = intrincis
        self.rel_pose = rel_pose

    def get_reference_explain_mask(self, downscaling):
        """Return an all-explainable reference mask for scale `downscaling`.

        Shape: [batch, H/2**s, W/2**s, 2] with each pixel one-hot [0, 1]
        (class 1 = "explainable"), matching the softmax target layout of
        compute_exp_reg_loss().
        """
        opt = self.opt
        tmp = np.array([0, 1])
        ref_exp_mask = np.tile(tmp,
                               (opt.batch_size,
                                int(opt.img_height / (2 ** downscaling)),
                                int(opt.img_width / (2 ** downscaling)),
                                1))
        ref_exp_mask = tf.constant(ref_exp_mask, dtype=tf.float32)
        return ref_exp_mask

    def compute_exp_reg_loss(self, pred, ref):
        """Mean cross-entropy between explainability logits and a reference mask."""
        l = tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.reshape(ref, [-1, 2]),
            logits=tf.reshape(pred, [-1, 2]))
        return tf.reduce_mean(l)

    def compute_smooth_loss(self, pred_disp):
        """Second-order smoothness loss on a disparity map.

        Mean absolute value of the four second derivatives (dx2, dxdy,
        dydx, dy2), computed via repeated first differences.
        """
        # Reuse the shared finite-difference helper instead of redefining
        # an identical local `gradient` function.
        dx, dy = self.gradient_sobel(pred_disp)
        dx2, dxdy = self.gradient_sobel(dx)
        dydx, dy2 = self.gradient_sobel(dy)
        return tf.reduce_mean(tf.abs(dx2)) + \
               tf.reduce_mean(tf.abs(dxdy)) + \
               tf.reduce_mean(tf.abs(dydx)) + \
               tf.reduce_mean(tf.abs(dy2))

    def collect_summaries(self):
        """Register TensorBoard summaries: total loss, the input target
        image, and a histogram per trainable variable. Must be called
        after build_train_graph()."""
        tf.summary.scalar("total_loss", self.total_loss)
        # Tag had an accidental trailing space ("image_summary ") which
        # TensorFlow would sanitize with a warning; fixed.
        tf.summary.image("image_summary", self.image_summary)
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name + "/values", var)

    def build_depth_test_graph(self):
        """Build the single-frame depth-inference graph.

        Feeds a uint8 batch through disp_net and keeps the finest-scale
        depth (1/disparity). Requires setup_inference() to have set
        batch_size / img_height / img_width.
        """
        input_uint8 = tf.placeholder(tf.uint8, [self.batch_size,
                    self.img_height, self.img_width, 3], name='raw_input')
        input_mc = self.preprocess_image(input_uint8)
        with tf.name_scope("depth_prediction"):
            pred_disp, depth_net_endpoints = disp_net(
                input_mc, is_training=False)
            # disp_net outputs multi-scale disparity; invert to depth.
            pred_depth = [1. / disp for disp in pred_disp]
        pred_depth = pred_depth[0]  # finest scale only
        self.inputs = input_uint8
        self.pred_depth = pred_depth
        self.depth_epts = depth_net_endpoints

    def build_pose_test_graph(self):
        """Build the pose-inference graph for a horizontally concatenated
        image sequence of length self.seq_length."""
        input_uint8 = tf.placeholder(tf.uint8, [self.batch_size,
            self.img_height, self.img_width * self.seq_length, 3],
            name='raw_input')
        input_mc = self.preprocess_image(input_uint8)
        # Split the wide sequence image into the target/source channel
        # stack expected by pose_exp_net.
        loader = DataLoader()
        image_seq_concat = \
            loader.batch_unpack_image_sequence(
                input_mc, self.img_height, self.img_width, self.num_source)
        with tf.name_scope("pose_prediction"):
            pred_poses, _, _ = pose_exp_net(
                image_seq_concat, do_exp=False, is_training=False)
            self.inputs = input_uint8
            self.pred_poses = pred_poses

    def preprocess_image(self, image):
        """Convert a uint8 image tensor to float32 in [-1, 1]."""
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        return image * 2. - 1.

    def deprocess_image(self, image):
        """Convert a float32 image tensor in [-1, 1] back to uint8."""
        image = (image + 1.) / 2.
        return tf.image.convert_image_dtype(image, dtype=tf.uint8)

    def setup_inference(self,
                        img_height,
                        img_width,
                        mode,
                        seq_length=3,
                        batch_size=1):
        """Configure and build the test graph for mode 'depth' or 'pose'."""
        self.img_height = img_height
        self.img_width = img_width
        self.mode = mode
        self.batch_size = batch_size
        if self.mode == 'depth':
            self.build_depth_test_graph()
        if self.mode == 'pose':
            self.seq_length = seq_length
            self.num_source = seq_length - 1
            self.build_pose_test_graph()

    def inference(self, inputs, sess, mode='depth'):
        """Run one forward pass; returns {'depth': ...} or {'pose': ...}."""
        fetches = {}
        if mode == 'depth':
            fetches['depth'] = self.pred_depth
        if mode == 'pose':
            fetches['pose'] = self.pred_poses
        results = sess.run(fetches, feed_dict={self.inputs: inputs})
        return results

    def save(self, sess, checkpoint_dir, step):
        """Save a checkpoint; step='latest' writes a rolling 'model.latest',
        otherwise the step number is appended by the saver."""
        model_name = 'model'
        print(" [*] Saving checkpoint to %s..." % checkpoint_dir)
        if step == 'latest':
            self.saver.save(sess,
                            os.path.join(checkpoint_dir, model_name + '.latest'))
        else:
            self.saver.save(sess,
                            os.path.join(checkpoint_dir, model_name),
                            global_step=step)
